#pragma once

#ifndef WINML_H_
#define WINML_H_

#include "weak_buffer.h"
#include "buffer_backed_random_access_stream_reference.h"
#include "weak_single_threaded_iterable.h"

// Evaluates `expression` once; if FAILED() reports failure for the result, returns it from the
// enclosing function converted to int32_t. The enclosing function must therefore return a type
// constructible from int32_t. Wrapped in do/while(0) so it behaves as a single statement.
#define RETURN_HR_IF_FAILED(expression) \
  do {                                  \
    auto _hr = expression;              \
    if (FAILED(_hr)) {                  \
      return static_cast<int32_t>(_hr); \
    }                                   \
  } while (0)

// Evaluates `expression` once; if FAILED() reports failure for the result, calls __fastfail with
// the result converted to int32_t instead of returning an error to the caller.
// Wrapped in do/while(0) so it behaves as a single statement.
#define FAIL_FAST_IF_HR_FAILED(expression)   \
  do {                                       \
    auto _hr = expression;                   \
    if (FAILED(_hr)) {                       \
      __fastfail(static_cast<int32_t>(_hr)); \
    }                                        \
  } while (0)

// Distinct 16-bit element type used to select the Float16Bit tensor specializations below
// (it keeps those specializations separate from the uint16_t ones).
struct float16 {
  uint16_t value;  // 16-bit payload; NOTE(review): presumably the raw half-float bit pattern - confirm
};

namespace Microsoft {
namespace AI {
namespace MachineLearning {
namespace Details {

// Forward declarations so the wrapper classes below can name each other as friends / parameters.
class WinMLLearningModel;
class WinMLLearningModelBinding;
class WinMLLearningModelSession;
class WinMLLearningModelResults;

// File name of the WinML binary that GetActivationFactory loads with LoadLibraryExW.
// Defined in the header with __declspec(selectany).
extern const __declspec(selectany) _Null_terminated_ wchar_t MachineLearningDll[] = L"microsoft.ai.machinelearning.dll";

// Trait mapping an element type T to the matching WinML ABI tensor interface (Tensor<T>::Type).
// The primary template has no members, so Tensor<T>::Type does not compile for an unlisted T.
template <typename T>
struct Tensor {};
template <>
struct Tensor<float> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorFloat;
};
template <>
struct Tensor<float16> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorFloat16Bit;
};
template <>
struct Tensor<int8_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt8Bit;
};
template <>
struct Tensor<uint8_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt8Bit;
};
template <>
struct Tensor<uint16_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt16Bit;
};
template <>
struct Tensor<int16_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt16Bit;
};
template <>
struct Tensor<uint32_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt32Bit;
};
template <>
struct Tensor<int32_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt32Bit;
};
template <>
struct Tensor<uint64_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorUInt64Bit;
};
template <>
struct Tensor<int64_t> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorInt64Bit;
};
template <>
struct Tensor<bool> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorBoolean;
};
template <>
struct Tensor<double> {
  using Type = ABI::Microsoft::AI::MachineLearning::ITensorDouble;
};

// Trait mapping an element type T to the runtime class name of the matching WinML tensor class
// (TensorRuntimeClassID<T>::RuntimeClass_ID). The primary template has no members, so an
// unlisted T does not compile. Each specialization only declares the static member; the
// definitions follow the declarations.
template <typename T>
struct TensorRuntimeClassID {};
template <>
struct TensorRuntimeClassID<float> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<float16> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<int8_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<uint8_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<uint16_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<int16_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<uint32_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<int32_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<uint64_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<int64_t> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<bool> {
  static const wchar_t* RuntimeClass_ID;
};
template <>
struct TensorRuntimeClassID<double> {
  static const wchar_t* RuntimeClass_ID;
};

// Definitions of the static members declared above. They live in this header and are marked
// __declspec(selectany).
__declspec(selectany) const wchar_t* TensorRuntimeClassID<float>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorFloat;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<float16>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorFloat16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<int8_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorInt8Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<uint8_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt8Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<uint16_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<int16_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorInt16Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<uint32_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt32Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<int32_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorInt32Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<uint64_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorUInt64Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<int64_t>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorInt64Bit;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<bool>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorBoolean;
__declspec(selectany) const wchar_t* TensorRuntimeClassID<double>::RuntimeClass_ID =
  RuntimeClass_Microsoft_AI_MachineLearning_TensorDouble;

// Trait mapping an element type T to the "Statics" factory interface of the matching WinML
// tensor class (TensorFactory<T>::Factory). WinMLLearningModelBinding::bind calls CreateFromArray
// on this interface. The primary template has no members, so an unlisted T does not compile.
template <typename T>
struct TensorFactory {};
template <>
struct TensorFactory<float> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloatStatics;
};
template <>
struct TensorFactory<float16> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloat16BitStatics;
};
template <>
struct TensorFactory<int8_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt8BitStatics;
};
template <>
struct TensorFactory<uint8_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt8BitStatics;
};
template <>
struct TensorFactory<uint16_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt16BitStatics;
};
template <>
struct TensorFactory<int16_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt16BitStatics;
};
template <>
struct TensorFactory<uint32_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt32BitStatics;
};
template <>
struct TensorFactory<int32_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt32BitStatics;
};
template <>
struct TensorFactory<uint64_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt64BitStatics;
};
template <>
struct TensorFactory<int64_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt64BitStatics;
};
template <>
struct TensorFactory<bool> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorBooleanStatics;
};
template <>
struct TensorFactory<double> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorDoubleStatics;
};

// Trait mapping an element type T to the "Statics2" factory interface of the matching WinML
// tensor class (TensorFactory2<T>::Factory). WinMLLearningModelBinding::bind_as_reference calls
// CreateFromBuffer on this interface. The primary template has no members, so an unlisted T
// does not compile.
template <typename T>
struct TensorFactory2 {};
template <>
struct TensorFactory2<float> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloatStatics2;
};
template <>
struct TensorFactory2<float16> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorFloat16BitStatics2;
};
template <>
struct TensorFactory2<int8_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt8BitStatics2;
};
template <>
struct TensorFactory2<uint8_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt8BitStatics2;
};
template <>
struct TensorFactory2<uint16_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt16BitStatics2;
};
template <>
struct TensorFactory2<int16_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt16BitStatics2;
};
template <>
struct TensorFactory2<uint32_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt32BitStatics2;
};
template <>
struct TensorFactory2<int32_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt32BitStatics2;
};
template <>
struct TensorFactory2<uint64_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorUInt64BitStatics2;
};
template <>
struct TensorFactory2<int64_t> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorInt64BitStatics2;
};
template <>
struct TensorFactory2<bool> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorBooleanStatics2;
};
template <>
struct TensorFactory2<double> {
  using Factory = ABI::Microsoft::AI::MachineLearning::ITensorDoubleStatics2;
};

// Trait mapping an element type T to the interface ID of the matching "Statics" tensor factory
// (TensorFactoryIID<T>::IID); this is the IID that pairs with TensorFactory<T>::Factory.
// The primary template has no members, so an unlisted T does not compile.
template <typename T>
struct TensorFactoryIID {};
template <>
struct TensorFactoryIID<float> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<float16> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<int8_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<uint8_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<uint16_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<int16_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<uint32_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<int32_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<uint64_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<int64_t> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<bool> {
  static const GUID IID;
};
template <>
struct TensorFactoryIID<double> {
  static const GUID IID;
};

// Definitions of the static members declared above, each copied from the ABI-provided
// IID_ITensor*Statics constant. They live in this header and are marked __declspec(selectany).
__declspec(selectany) const GUID TensorFactoryIID<float>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorFloatStatics;
__declspec(selectany) const GUID TensorFactoryIID<float16>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorFloat16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<int8_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt8BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<uint8_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt8BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<uint16_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<int16_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt16BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<uint32_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt32BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<int32_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt32BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<uint64_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt64BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<int64_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt64BitStatics;
__declspec(selectany) const GUID TensorFactoryIID<bool>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorBooleanStatics;
__declspec(selectany) const GUID TensorFactoryIID<double>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorDoubleStatics;

// Trait mapping an element type T to the interface ID of the matching "Statics2" tensor factory
// (TensorFactory2IID<T>::IID); this is the IID that pairs with TensorFactory2<T>::Factory.
// The primary template has no members, so an unlisted T does not compile.
template <typename T>
struct TensorFactory2IID {};
template <>
struct TensorFactory2IID<float> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<float16> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<int8_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<uint8_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<uint16_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<int16_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<uint32_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<int32_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<uint64_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<int64_t> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<bool> {
  static const GUID IID;
};
template <>
struct TensorFactory2IID<double> {
  static const GUID IID;
};

// Definitions of the static members declared above, each copied from the ABI-provided
// IID_ITensor*Statics2 constant. They live in this header and are marked __declspec(selectany).
__declspec(selectany) const GUID TensorFactory2IID<float>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorFloatStatics2;
__declspec(selectany) const GUID TensorFactory2IID<float16>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorFloat16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<int8_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt8BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<uint8_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt8BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<uint16_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<int16_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt16BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<uint32_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt32BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<int32_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt32BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<uint64_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorUInt64BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<int64_t>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorInt64BitStatics2;
__declspec(selectany) const GUID TensorFactory2IID<bool>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorBooleanStatics2;
__declspec(selectany) const GUID TensorFactory2IID<double>::IID =
  ABI::Microsoft::AI::MachineLearning::IID_ITensorDoubleStatics2;

// Obtains interface `iid` of the activation factory for runtime class `p_class_id` by loading
// MachineLearningDll and calling its exported DllGetActivationFactory directly.
//
//   p_class_id  null-terminated runtime class name (its length is taken with wcslen).
//   iid         interface requested from the activation factory.
//   factory     receives the interface pointer from QueryInterface.
//
// Returns the failing HRESULT when the DLL cannot be loaded, the export is missing, the export
// itself fails, or the factory does not support `iid`. The module is released again only on the
// export-lookup and activation failure paths; otherwise it is left loaded.
// NOTE(review): no fallback to another binary is attempted when the load fails.
inline HRESULT GetActivationFactory(const wchar_t* p_class_id, const IID& iid, void** factory) noexcept {
  using DllGetActivationFactoryFn = HRESULT __stdcall(HSTRING, void** factory);

  const auto module_handle = LoadLibraryExW(MachineLearningDll, nullptr, 0);
  if (module_handle == nullptr) {
    return HRESULT_FROM_WIN32(GetLastError());
  }

  const auto entry_point =
    reinterpret_cast<DllGetActivationFactoryFn*>(GetProcAddress(module_handle, "DllGetActivationFactory"));
  if (entry_point == nullptr) {
    // Capture the error before FreeLibrary gets a chance to overwrite it.
    const auto lookup_hr = HRESULT_FROM_WIN32(GetLastError());
    FreeLibrary(module_handle);
    return lookup_hr;
  }

  const Microsoft::WRL::Wrappers::HStringReference class_id(
    p_class_id, static_cast<unsigned int>(wcslen(p_class_id))
  );
  Microsoft::WRL::ComPtr<IActivationFactory> activation_factory;
  const auto activation_hr =
    entry_point(class_id.Get(), reinterpret_cast<void**>(activation_factory.GetAddressOf()));
  if (FAILED(activation_hr)) {
    FreeLibrary(module_handle);
    return activation_hr;
  }

  return activation_factory->QueryInterface(iid, factory);
}

class WinMLLearningModel {
  friend class WinMLLearningModelSession;

 public:
  WinMLLearningModel(const wchar_t* model_path, size_t size) { ML_FAIL_FAST_IF(0 != Initialize(model_path, size)); }

  WinMLLearningModel(const char* bytes, size_t size) {
    ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*dont copy*/));
  }

 private:
  int32_t Initialize(const wchar_t* model_path, size_t size) {
    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learningModel;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
      ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
      &learningModel
    ));

    RETURN_HR_IF_FAILED(learningModel->LoadFromFilePath(
      Microsoft::WRL::Wrappers::HStringReference(model_path, static_cast<unsigned int>(size)).Get(),
      m_learning_model.GetAddressOf()
    ));
    return 0;
  }

  struct StoreCompleted : Microsoft::WRL::RuntimeClass<
                            Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
                            ABI::Windows::Foundation::IAsyncOperationCompletedHandler<uint32_t>> {
    HANDLE completed_event_;

    StoreCompleted() : completed_event_(CreateEvent(nullptr, true, false, nullptr)) {}

    ~StoreCompleted() { CloseHandle(completed_event_); }

    HRESULT STDMETHODCALLTYPE Invoke(
      ABI::Windows::Foundation::IAsyncOperation<uint32_t>* /*asyncInfo*/,
      ABI::Windows::Foundation::AsyncStatus /*status*/
    ) {
      SetEvent(completed_event_);
      return S_OK;
    }

    HRESULT Wait() {
      WaitForSingleObject(completed_event_, INFINITE);
      return S_OK;
    }
  };

  int32_t Initialize(const char* bytes, size_t size, bool with_copy = false) {
    auto hr = RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
    // https://docs.microsoft.com/en-us/windows/win32/api/roapi/nf-roapi-roinitialize#return-value
    // RPC_E_CHANGED_MODE indicates already initialized as multithreaded
    if (hr < 0 && hr != RPC_E_CHANGED_MODE) {
      return static_cast<int32_t>(hr);
    }

    Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
    if (with_copy) {
      // Create in memory stream
      Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
      RETURN_HR_IF_FAILED(RoActivateInstance(
        Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream)
          .Get(),
        in_memory_random_access_stream_insp.GetAddressOf()
      ));

      // QI memory stream to output stream
      Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
      RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));

      // Create data writer factory
      Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
      RETURN_HR_IF_FAILED(RoGetActivationFactory(
        Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
        IID_PPV_ARGS(activation_factory.GetAddressOf())
      ));

      // Create data writer object based on the in memory stream
      Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriter> data_writer;
      RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(output_stream.Get(), data_writer.GetAddressOf()));

      // Write the model to the data writer and thus to the stream
      RETURN_HR_IF_FAILED(
        data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char*>(bytes)))
      );

      // QI the in memory stream to a random access stream
      Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
      RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));

      // Create a random access stream reference factory
      Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReferenceStatics>
        random_access_stream_ref_statics;
      RETURN_HR_IF_FAILED(RoGetActivationFactory(
        Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference)
          .Get(),
        IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())
      ));

      // Create a random access stream reference from the random access stream view on top of
      // the in memory stream
      RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
        random_access_stream.Get(), random_access_stream_ref.GetAddressOf()
      ));

      Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperation<uint32_t>> async_operation;
      RETURN_HR_IF_FAILED(data_writer->StoreAsync(&async_operation));
      auto store_completed_handler = Microsoft::WRL::Make<StoreCompleted>();
      RETURN_HR_IF_FAILED(async_operation->put_Completed(store_completed_handler.Get()));
      RETURN_HR_IF_FAILED(store_completed_handler->Wait());

    } else {
      Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<char>> buffer;
      RETURN_HR_IF_FAILED(Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<char>>(&buffer, bytes, bytes + size));

      RETURN_HR_IF_FAILED(
        Microsoft::WRL::MakeAndInitialize<WinMLTest::BufferBackedRandomAccessStreamReference>(
          &random_access_stream_ref, buffer.Get()
        )
      );
    }

    // Create a learning model factory
    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learning_model;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
      ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
      &learning_model
    ));

    // Create a learning model from the factory with the random access stream reference that points
    // to the random access stream view on top of the in memory stream copy of the model
    RETURN_HR_IF_FAILED(
      learning_model->LoadFromStream(random_access_stream_ref.Get(), m_learning_model.GetAddressOf())
    );

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModel> m_learning_model;
};

class WinMLLearningModelResults {
  friend class WinMLLearningModelSession;

 public:
  int32_t get_output(const wchar_t* feature_name, size_t feature_name_size, void** pp_buffer, size_t* p_capacity) {
    Microsoft::WRL::ComPtr<ABI::Windows::Foundation::Collections::IMapView<HSTRING, IInspectable*>> output_map;
    RETURN_HR_IF_FAILED(m_result->get_Outputs(&output_map));

    Microsoft::WRL::ComPtr<IInspectable> inspectable;
    RETURN_HR_IF_FAILED(output_map->Lookup(
      Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast<unsigned int>(feature_name_size)).Get(),
      inspectable.GetAddressOf()
    ));

    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelFeatureValue> output_feature_value;
    RETURN_HR_IF_FAILED(inspectable.As(&output_feature_value));

    Microsoft::WRL::ComPtr<ITensorNative> native_tensor_float_feature_value;
    RETURN_HR_IF_FAILED(output_feature_value.As(&native_tensor_float_feature_value));

    uint32_t size;
    RETURN_HR_IF_FAILED(native_tensor_float_feature_value->GetBuffer(reinterpret_cast<BYTE**>(pp_buffer), &size));
    *p_capacity = size;

    return 0;
  }

 private:
  WinMLLearningModelResults(ABI::Microsoft::AI::MachineLearning::ILearningModelEvaluationResult* p_result) {
    m_result = p_result;
  }

 private:
  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelEvaluationResult> m_result;
};

class WinMLLearningModelBinding {
  friend class WinMLLearningModelSession;

 public:
  WinMLLearningModelBinding(const WinMLLearningModelSession& session) { ML_FAIL_FAST_IF(0 != Initialize(session)); }

  template <typename T = float>
  int32_t bind(
    const wchar_t* feature_name,
    size_t feature_name_size,
    tensor_shape_type* p_shape,
    size_t shape_size,
    T* p_data,
    size_t data_size
  ) {
    using ITensor = typename Tensor<T>::Type;
    using ITensorFactory = typename TensorFactory<T>::Factory;

    Microsoft::WRL::ComPtr<ITensorFactory> tensor_factory;
    RETURN_HR_IF_FAILED(
      GetActivationFactory(TensorRuntimeClassID<T>::RuntimeClass_ID, TensorFactoryIID<T>::IID, &tensor_factory)
    );

    Microsoft::WRL::ComPtr<weak_single_threaded_iterable<int64_t>> input_shape_iterable;
    RETURN_HR_IF_FAILED(
      Microsoft::WRL::MakeAndInitialize<weak_single_threaded_iterable<int64_t>>(
        &input_shape_iterable, p_shape, p_shape + shape_size
      )
    );

    Microsoft::WRL::ComPtr<ITensor> tensor;
    RETURN_HR_IF_FAILED(tensor_factory->CreateFromArray(
      input_shape_iterable.Get(), static_cast<uint32_t>(data_size), p_data, tensor.GetAddressOf()
    ));

    Microsoft::WRL::ComPtr<IInspectable> inspectable_tensor;
    RETURN_HR_IF_FAILED(tensor.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(m_learning_model_binding->Bind(
      Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast<unsigned int>(feature_name_size)).Get(),
      inspectable_tensor.Get()
    ));
    return 0;
  }

  template <typename T = float>
  int32_t bind(
    const wchar_t* /*feature_name*/, size_t /*feature_name_size*/, tensor_shape_type* /*p_shape*/, size_t /*shape_size*/
  ) {
    return 0;
  }

  template <typename T = float>
  int32_t bind_as_reference(
    const wchar_t* feature_name,
    size_t feature_name_size,
    tensor_shape_type* p_shape,
    size_t shape_size,
    T* p_data,
    size_t data_size
  ) {
    using ITensor = typename Tensor<T>::Type;
    using ITensorFactory = typename TensorFactory2<T>::Factory;

    Microsoft::WRL::ComPtr<ITensorFactory> tensor_factory;
    RETURN_HR_IF_FAILED(
      GetActivationFactory(TensorRuntimeClassID<T>::RuntimeClass_ID, TensorFactory2IID<T>::IID, &tensor_factory)
    );

    Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<T>> buffer;
    RETURN_HR_IF_FAILED(
      Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(&buffer, p_data, p_data + data_size)
    );

    Microsoft::WRL::ComPtr<ITensor> tensor;
    RETURN_HR_IF_FAILED(
      tensor_factory->CreateFromBuffer(static_cast<uint32_t>(shape_size), p_shape, buffer.Get(), tensor.GetAddressOf())
    );

    Microsoft::WRL::ComPtr<IInspectable> inspectable_tensor;
    RETURN_HR_IF_FAILED(tensor.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(m_learning_model_binding->Bind(
      Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast<unsigned int>(feature_name_size)).Get(),
      inspectable_tensor.Get()
    ));
    return 0;
  }

  template <typename T = float>
  int32_t bind_as_references(
    const wchar_t* feature_name, size_t feature_name_size, T** p_data, size_t* data_sizes, size_t num_buffers
  ) {
    using ITensor = typename Tensor<T>::Type;
    using ITensorFactory = typename TensorFactory2<T>::Factory;

    std::vector<Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer>> vec_buffers(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      RETURN_HR_IF_FAILED(
        Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
          &vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]
        )
      );
    }

    std::vector<ABI::Windows::Storage::Streams::IBuffer*> raw_buffers(num_buffers);
    std::transform(std::begin(vec_buffers), std::end(vec_buffers), std::begin(raw_buffers), [](auto buffer) {
      return buffer.Detach();
    });

    Microsoft::WRL::ComPtr<weak_single_threaded_iterable<ABI::Windows::Storage::Streams::IBuffer*>> buffers;
    RETURN_HR_IF_FAILED(
      Microsoft::WRL::MakeAndInitialize<weak_single_threaded_iterable<ABI::Windows::Storage::Streams::IBuffer*>>(
        &buffers, raw_buffers.data(), raw_buffers.data() + num_buffers
      )
    );

    Microsoft::WRL::ComPtr<IInspectable> inspectable_tensor;
    RETURN_HR_IF_FAILED(buffers.As(&inspectable_tensor));

    RETURN_HR_IF_FAILED(m_learning_model_binding->Bind(
      Microsoft::WRL::Wrappers::HStringReference(feature_name, static_cast<unsigned int>(feature_name_size)).Get(),
      inspectable_tensor.Get()
    ));
    return 0;
  }

 private:
  inline int32_t Initialize(const WinMLLearningModelSession& session);

 private:
  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelBinding> m_learning_model_binding;
};

class WinMLLearningModelDevice {
  friend class WinMLLearningModelSession;

 public:
  WinMLLearningModelDevice()
    : WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_Default) {}

  WinMLLearningModelDevice(WinMLLearningModelDevice&& device)
    : m_learning_model_device(std::move(device.m_learning_model_device)) {}

  WinMLLearningModelDevice(const WinMLLearningModelDevice& device)
    : m_learning_model_device(device.m_learning_model_device) {}

  void operator=(const WinMLLearningModelDevice& device) { m_learning_model_device = device.m_learning_model_device; }

  WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind kind) {
    ML_FAIL_FAST_IF(0 != Initialize(kind));
  }

  WinMLLearningModelDevice(ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice) {
    ML_FAIL_FAST_IF(0 != Initialize(d3dDevice));
  }

  WinMLLearningModelDevice(ID3D12CommandQueue* queue) { ML_FAIL_FAST_IF(0 != Initialize(queue)); }

  static WinMLLearningModelDevice create_cpu_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_Cpu);
  }

  static WinMLLearningModelDevice create_directx_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectX);
  }

  static WinMLLearningModelDevice create_directx_high_power_device() {
    return WinMLLearningModelDevice(
      ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectXHighPerformance
    );
  }

  static WinMLLearningModelDevice create_directx_min_power_device() {
    return WinMLLearningModelDevice(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind_DirectXMinPower);
  }

  static WinMLLearningModelDevice create_directx_device(
    ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice
  ) {
    return WinMLLearningModelDevice(d3dDevice);
  }

  static WinMLLearningModelDevice create_directx_device(ID3D12CommandQueue* queue) {
    return WinMLLearningModelDevice(queue);
  }

 private:
  int32_t Initialize(ABI::Microsoft::AI::MachineLearning::LearningModelDeviceKind kind) {
    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelDeviceFactory>
      learning_model_device_factory;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
      ABI::Microsoft::AI::MachineLearning::IID_ILearningModelDeviceFactory,
      &learning_model_device_factory
    ));

    RETURN_HR_IF_FAILED(learning_model_device_factory->Create(kind, &m_learning_model_device));

    return 0;
  }

  int32_t Initialize(ABI::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice* d3dDevice) {
    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelDeviceStatics>
      learning_model_device_factory;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
      ABI::Microsoft::AI::MachineLearning::IID_ILearningModelDeviceStatics,
      &learning_model_device_factory
    ));

    RETURN_HR_IF_FAILED(learning_model_device_factory->CreateFromDirect3D11Device(d3dDevice, &m_learning_model_device));

    return 0;
  }

  int32_t Initialize(ID3D12CommandQueue* queue) {
    Microsoft::WRL::ComPtr<ILearningModelDeviceFactoryNative> learning_model_device_factory;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModelDevice,
      __uuidof(ILearningModelDeviceFactoryNative),
      &learning_model_device_factory
    ));

    RETURN_HR_IF_FAILED(learning_model_device_factory->CreateFromD3D12CommandQueue(queue, &m_learning_model_device));

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelDevice> m_learning_model_device;
};

class WinMLLearningModelSession {
  friend class WinMLLearningModelBinding;

 public:
  using Model = WinMLLearningModel;
  using Device = WinMLLearningModelDevice;

 public:
  WinMLLearningModelSession(const Model& model) { ML_FAIL_FAST_IF(0 != Initialize(model, Device())); }

  WinMLLearningModelSession(const Model& model, const Device& device) {
    ML_FAIL_FAST_IF(0 != Initialize(model, device));
  }

  WinMLLearningModelResults evaluate(WinMLLearningModelBinding& binding) {
    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelEvaluationResult>
      m_learning_model_evaluation_result;

    FAIL_FAST_IF_HR_FAILED(m_learning_model_session->Evaluate(
      binding.m_learning_model_binding.Get(), nullptr, m_learning_model_evaluation_result.GetAddressOf()
    ));

    return WinMLLearningModelResults(m_learning_model_evaluation_result.Get());
  }

 private:
  int32_t Initialize(const Model& model, const Device& device) {
    // {d7d86c54-d03d-5ae3-a958-fe952b640620}
    static const GUID IID_ILearningModelSessionFactory = {
      0xd7d86c54, 0xd03d, 0x5ae3, {0xa9, 0x58, 0xfe, 0x95, 0x2b, 0x64, 0x06, 0x20}
    };

    Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelSessionFactory>
      m_learning_model_session_factory;
    RETURN_HR_IF_FAILED(GetActivationFactory(
      RuntimeClass_Microsoft_AI_MachineLearning_LearningModelSession,
      IID_ILearningModelSessionFactory,
      &m_learning_model_session_factory
    ));

    RETURN_HR_IF_FAILED(m_learning_model_session_factory->CreateFromModelOnDevice(
      model.m_learning_model.Get(), device.m_learning_model_device.Get(), m_learning_model_session.GetAddressOf()
    ));

    return 0;
  }

 private:
  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelSession> m_learning_model_session;
};

inline int32_t WinMLLearningModelBinding::Initialize(const WinMLLearningModelSession& session) {
  // {ae2f1c97-2fd5-55b9-a05f-53b9dbb4f9e2}
  static const GUID IID_ILearningModelBindingFactory = {
    0xae2f1c97, 0x2fd5, 0x55b9, {0xa0, 0x5f, 0x53, 0xb9, 0xdb, 0xb4, 0xf9, 0xe2}
  };

  Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelBindingFactory>
    learning_model_binding_factory;

  RETURN_HR_IF_FAILED(GetActivationFactory(
    RuntimeClass_Microsoft_AI_MachineLearning_LearningModelBinding,
    IID_ILearningModelBindingFactory,
    &learning_model_binding_factory
  ));

  RETURN_HR_IF_FAILED(learning_model_binding_factory->CreateFromSession(
    session.m_learning_model_session.Get(), m_learning_model_binding.GetAddressOf()
  ));

  return 0;
}

}  // namespace Details
}  // namespace MachineLearning
}  // namespace AI
}  // namespace Microsoft

#endif  // WINML_H_
